Add option for selective op AC to filter mm shapes based on fqn #1380


Open

soulitzer wants to merge 1 commit into main from soulitzer/add-sac-psuedo-fqn-policy

Conversation

@soulitzer commented Jul 11, 2025

Also see discussion in #1372

This PR:

  • Adds a new config for SAC, with a default such that per-op SAC automatically force-recomputes (rather than saves) all mms whose args[1].shape matches that of the Linear at fqn "moe.router.gate" (see the sketch after this list)
  • Adds general flop/act-mem/correctness tests for AC as well as the new config
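
For context, a minimal sketch of the kind of policy this enables (not the PR's exact implementation; helper names are illustrative), based on the selective-checkpoint hooks in torch.utils.checkpoint:

import torch
import torch.nn as nn
from torch.utils.checkpoint import CheckpointPolicy

def mm_recompute_shapes(module: nn.Module, filter_fqns: list[str]) -> set:
    # For F.linear, aten.mm sees args[1] == weight.t(), i.e. shape
    # (in_features, out_features); collect those shapes for matching Linears.
    shapes = set()
    for fqn, submod in module.named_modules():
        if isinstance(submod, nn.Linear) and any(f in fqn for f in filter_fqns):
            shapes.add(torch.Size([submod.in_features, submod.out_features]))
    return shapes

def make_policy(recompute_shapes: set):
    def policy(ctx, func, *args, **kwargs):
        if func == torch.ops.aten.mm.default:
            if args[1].shape in recompute_shapes:
                # e.g. the tiny router-gate mm: cheaper to recompute than to save
                return CheckpointPolicy.PREFER_RECOMPUTE
            return CheckpointPolicy.MUST_SAVE  # other mms are compute-intensive
        return CheckpointPolicy.PREFER_RECOMPUTE
    return policy

A policy like this would be wired into checkpointing via functools.partial(torch.utils.checkpoint.create_selective_checkpoint_contexts, make_policy(shapes)) passed as context_fn to torch.utils.checkpoint.checkpoint.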

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jul 11, 2025
@soulitzer force-pushed the soulitzer/add-sac-psuedo-fqn-policy branch 4 times, most recently from 9e3b49b to 3c4d97d on July 11, 2025 14:38
@tianyu-l (Contributor) left a comment

This is a great idea!
Since this feature is advanced, could you also help verify that the behavior is as expected?

It seems this feature does not require distributed, so maybe we can add a unit test file in
https://github.com/pytorch/torchtitan/tree/main/tests/unit_tests

But if it doesn't make sense, feel free to do it in the way you prefer.

@@ -487,6 +487,20 @@ class ActivationCheckpoint:
        'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
    """

    selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field(
        default_factory=lambda: []
Contributor:

Seems good enough to just do:

Suggested change:
-    default_factory=lambda: []
+    default_factory=list

or we can default to ["moe.router.gate"] so that we don't need to define it in a lot of tomls.
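
For example, a trimmed-down, self-contained sketch of that alternative default (only the new field shown):

from dataclasses import dataclass, field

@dataclass
class ActivationCheckpoint:
    selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field(
        default_factory=lambda: ["moe.router.gate"]
    )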

Otherwise, could you please also update the tomls in
https://github.com/pytorch/torchtitan/tree/main/torchtitan/experiments/llama4/train_configs
and
https://github.com/pytorch/torchtitan/tree/main/torchtitan/models/deepseek_v3/train_configs

Comment on lines 267 to 277
if (
    fqn
    not in ac_config.selective_op_ac_force_recompute_mm_shapes_by_fqns
):
Contributor:

Note that in float8 we also filter by fqns, but there we do it the other way around (checking whether a filter fqn is a substring of the module's full fqn):
https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/quantization/utils.py#L25
I think one reason could be that the filter there is applied to the whole model, so one fqn can map to multiple layers / modules.

I think for AC there's not much difference between the two. The benefit of doing it the other way is that users don't need to specify the exact full relative fqn within the AC region; e.g., "router.gate" would also work.

I don't have a strong preference, but maybe let's be consistent with float8 if you don't have strong preference either.
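
For reference, a minimal sketch of the substring-style matching described above (the helper name matches_filter is made up for illustration):

def matches_filter(module_fqn: str, filter_fqns: list[str]) -> bool:
    # float8-style: a module matches if any filter fqn is a substring of its
    # full name, so "router.gate" also matches "layers.0.moe.router.gate".
    return any(filter_fqn in module_fqn for filter_fqn in filter_fqns)

assert matches_filter("layers.0.moe.router.gate", ["router.gate"])
assert not matches_filter("layers.0.attention.wq", ["router.gate"])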

@@ -487,6 +487,20 @@ class ActivationCheckpoint:
        'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
    """

    selective_op_ac_force_recompute_mm_shapes_by_fqns: list[str] = field(
Contributor:

I wonder if we should prefer a shorter name over one that fully spells out its meaning. How about per_op_sac_filter_fqns? Most users shouldn't really care about the implementation details; if some users do, they can check the help message and the implementation.

Author:

Good question. To help reduce the cognitive load of parsing the config file for the average user, who I agree won't care about the implementation, would it help if the default already included any MoE router fqns for TorchTitan models, per your other suggestion? That way most configs won't need to contain this option at all, so most users will never see it, while the advanced users who do use it still benefit from a more explicit name.

I think this is consistent with most users not being aware of what the per-op SAC policy is at all, although we could potentially refactor things so that we have specific named policies like policy="compute_intensive_excluding_every_other_matmul".

@tianyu-l (Contributor) left a comment

Thanks! I left some comments.

The unit test looks awesome, although IIRC our CPU unit test can't handle GPU tests
https://github.com/pytorch/torchtitan/blob/main/.github/workflows/unit_test_cpu.yaml

Do you think we can test AC on a CPU? If not, we can land the current one for now, and I'll try to find a way to run GPU unit tests later.

@@ -27,7 +27,7 @@
     SequenceParallel,
 )

-from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
+from torchtitan.config_manager import ActivationCheckpoint, JobConfig, TORCH_DTYPE_MAP
Contributor:

maybe ActivationCheckpoint as ACConfig
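
Presumably something like:

from torchtitan.config_manager import (
    ActivationCheckpoint as ACConfig,
    JobConfig,
    TORCH_DTYPE_MAP,
)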


    def test_correctness(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is unavailable")
Contributor:

maybe noob question:
Does AC require GPU to run? My intuition was it should be able to run on CPU.

class TestApplyAC(unittest.TestCase):
    def test_flops(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is unavailable")
Contributor:

similar question: Does AC / FlopCounterMode require GPU to run?

Author:

AC and FlopCounterMode should not require a GPU, but the peak memory stats do. I can factor out the flop counter test so that it still runs when we can only run on CPU.
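
As a rough illustration, a CPU-only flop check along those lines (the model and shapes here are made up, not the PR's test):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.utils.flop_counter import FlopCounterMode

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

with FlopCounterMode(display=False) as baseline:
    model(x)

# The same forward under checkpointing should report identical forward flops.
with FlopCounterMode(display=False) as checkpointed:
    checkpoint(model, x, use_reentrant=False)

assert checkpointed.get_total_flops() == baseline.get_total_flops()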

@@ -237,7 +237,9 @@ def apply_tp(
 }


-def _apply_ac_to_transformer_block(module: nn.Module, ac_config):
+def _apply_ac_to_transformer_block(
+    module: nn.Module, ac_config: ActivationCheckpoint, base_fqn: str
Contributor:

Since in torchtitan we only apply AC at the transformer block level, I feel the base_fqn arg is less needed, in the sense that it should be rare for a user to apply per-op SAC but want to filter the router.gate matmul only in layer 1 and not in layer 2.

Most use cases would be per_op_sac_force_recompute_mm_shapes_by_fqns = ["moe.router.gate"], and "moe.router.gate" is already in module_fqn without base_fqn.

If that's the case, I think it's not necessary to add this field. Let me know if you think otherwise.

@soulitzer (Author) commented Jul 15, 2025

I don't think it's necessary, but I wanted to be consistent with float8's fqn matching, which is based on the entire model's fqns. It also lets us avoid a line of documentation mentioning that fqns are actually relative to the TransformerBlock. Let me know if you'd still like it changed; I think either is fine, but I slightly prefer this direction.
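
For concreteness, a tiny sketch of the path composition under discussion (the helper name full_fqn is hypothetical):

def full_fqn(base_fqn: str, rel_fqn: str) -> str:
    # With base_fqn, filters match model-level paths (consistent with float8);
    # without it, they would match paths relative to the TransformerBlock.
    return f"{base_fqn}.{rel_fqn}" if base_fqn else rel_fqn

assert full_fqn("layers.1", "moe.router.gate") == "layers.1.moe.router.gate"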

@soulitzer force-pushed the soulitzer/add-sac-psuedo-fqn-policy branch from bfb3a32 to c2cdb20 on July 15, 2025 13:46
@soulitzer force-pushed the soulitzer/add-sac-psuedo-fqn-policy branch from c2cdb20 to b27dae7 on July 15, 2025 13:49